#! /usr/bin/env julia
# ---------------------------------------------------------


# ---------------------------------------------------------
using CSV, FileIO
using DataFrames
using Parameters
using Statistics: mean, std, cor, I, kron, UpperTriangular, median
using JLD2
using Printf

push!(LOAD_PATH, "./src")
using project_routines
# ---------------------------------------------------------


# ---------------------------------------------------------
# logging
# `mkpath` creates ./log if it does not exist (no-op otherwise) without spawning a
# shell command, so it also works where `mkdir -p` is unavailable.
mkpath("log");
using Logging, LoggingExtras, Dates

# clean up the log files of a previous run (files that do not exist are ignored)
foreach(path -> rm(path; force=true),
        ["./log/table9_info.log.jl", "./log/table9_warn.log.jl", "./log/table9_full.log.jl"]);
# custom timestamps: prefix every log message with the wall-clock time at which
# the record passes through the logger, on its own line above the message.
const log_date_format = "yyyy-mm-dd HH:MM:SS"
function timestamp_logger(logger)
    add_stamp(record) = merge(record,
        (message = string(Dates.format(now(), log_date_format), " \n", record.message),))
    return TransformerLogger(add_stamp, logger)
end

# create the logger: every record goes (by minimum level) to three log files and to
# the console, with a timestamp prefixed by `timestamp_logger`.
demux_logger = TeeLogger(
    MinLevelLogger(FileLogger("./log/table9_info.log.jl"), Logging.Info),
    MinLevelLogger(FileLogger("./log/table9_warn.log.jl"), Logging.Warn),
    MinLevelLogger(FileLogger("./log/table9_full.log.jl"), Logging.Debug),
    ConsoleLogger(stdout, Logging.Info),   # Common logger to be able to see info and warning in repl
) |> timestamp_logger;
# `global_logger(l)` installs `l` and returns the *previous* logger, so it is called
# separately to keep `demux_logger` bound to the logger that is actually in use.
global_logger(demux_logger);
# ---------------------------------------------------------


# ---------------------------------------------------------
# LOAD DATA AND PARAMETERS
calibration_parameters = CSV.read("./input/param_calibration.csv", DataFrame);
eta, psi, sigma, gamma, phi, years, dt, BURN_IN = Vector(calibration_parameters[1, :]);
N = 44;          # Number of countries
zj = ones(N, 1); # Average productivity (N×1, normalized to one)

country_names = CSV.read("./input/country_names.csv", DataFrame)
nrow(country_names) == N ||
    throw(DimensionMismatch("country_names.csv has $(nrow(country_names)) rows, expected N = $N"))

# Country size relative to the US
Lj = CSV.read("./input/Ljstar.csv", DataFrame).Lj # Normalize GDP relative to US GDP
length(Lj) == N ||
    throw(DimensionMismatch("Ljstar.csv has $(length(Lj)) entries, expected N = $N"))

# Import shares
# Row - country j (base); column - country i (exporter to j)
# The matrix is read from ./input like the other inputs; set the environment
# variable IMPORT_SHARE_MATRIX to read it from another location.
import_share_file = get(ENV, "IMPORT_SHARE_MATRIX", "./input/import_share_matrix.csv")
mij_data = CSV.read(import_share_file, DataFrame);

# Cells that are missing, not parseable as a number, or NaN are replaced by a small
# positive share (1e-6) so every bilateral share is finite. Numeric cells are kept.
share_value(x::Real) = isnan(x) ? 1e-6 : Float64(x)
share_value(x::AbstractString) = share_value(something(tryparse(Float64, x), NaN))
share_value(::Missing) = 1e-6

# drop the two leading identifier columns; result is a Matrix{Float64}
mij_data = share_value.(Matrix(mij_data[:, 3:end]))
size(mij_data) == (N, N) ||
    throw(DimensionMismatch("import share matrix is $(size(mij_data)), expected ($N, $N)"))
# ---------------------------------------------------------


# ---------------------------------------------------------
# Structural estimation of the trade costs matrix
#
# Outer loop (`calibrate_trade_costs`): for the current τ, solve for market-clearing
# prices, compute the model import shares and move τ against the data-minus-model gap.
# Inner loop (`solve_prices`): damped tâtonnement on excess demand. Excess demand is
# re-evaluated for every τ, so prices always correspond to the current trade costs.

"""
    solve_prices(τ, p0, S, eta; tol=1e-6, damping=0.3, maxiter=1_000_000)

Solve for prices given trade costs `τ` (N×N, row = importer, column = exporter).

Starting from `p0` (N×1), compute `excess[j] = Σ_i share_ij * S_i * p_i / p_j - S_j`
(1×N, `S` is the 1×N row of supplies) and update `p ← p .* (1 .+ damping .* excess')`.
Iteration stops after the first update made with `maximum(abs, excess) ≤ tol`;
excess demand is always evaluated at least once.

Returns `(p, excess)`. Throws an error if excess demand becomes non-finite, and
warns if `maxiter` updates are made without reaching `tol`.
"""
function solve_prices(τ, p0, S, eta; tol=1e-6, damping=0.3, maxiter=1_000_000)
    p = copy(p0)
    excess = fill(Inf, 1, length(p))
    for _ in 1:maxiter
        R = (τ .* p' ./ p) .^ (1 - eta)
        # row-normalized R are the CES import shares; column sums give demand per exporter
        excess = sum(R .* (p ./ p') ./ sum(R, dims=2) .* S', dims=1) .- S
        all(isfinite, excess) ||
            error("solve_prices: excess demand became non-finite")
        p = p .* (1.0 .+ damping .* excess')
        maximum(abs, excess) > tol || return p, excess
    end
    @warn "solve_prices: prices did not converge" maxiter max_excess = maximum(abs, excess)
    return p, excess
end

"""
    calibrate_trade_costs(mij_data, τ0, p0, S, eta;
                          maxiter=900, tol=1e-10, step=1.0, numeraire=size(τ0, 1))

Adjust trade costs until the model import shares match `mij_data` (N×N).

Each iteration:
1. re-solves prices with `solve_prices` (warm-started from the previous prices) and
   divides them by the price of country `numeraire`;
2. computes model shares `(τ_ij p_j)^(1-eta) / Σ_k (τ_ik p_k)^(1-eta)`;
3. sets `τ ← τ - step * (mij_data - shares)`, resets the diagonal to 1, floors all
   entries at 1 and applies `triangle_inequality` (from `project_routines`).

Stops after `maxiter` iterations, or earlier once the mean absolute off-diagonal
share gap is below `tol`. Returns `(τ, p, excess, dist, iterations)`.
"""
function calibrate_trade_costs(mij_data, τ0, p0, S, eta;
                               maxiter=900, tol=1e-10, step=1.0,
                               numeraire=size(τ0, 1))
    n = size(τ0, 1)
    offdiag = [i != j for i in 1:n, j in 1:n]
    τ, p = copy(τ0), copy(p0)
    excess = fill(Inf, 1, n)
    dist, iterations = Inf, 0
    for iter in 1:maxiter
        iterations = iter
        # 1. market-clearing prices for the current τ, relative to the numeraire
        p, excess = solve_prices(τ, p, S, eta)
        p = p ./ p[numeraire]
        # 2. model import shares and their gap to the data
        weights = (τ .* p') .^ (1 - eta)
        gap = mij_data .- weights ./ sum(weights, dims=2)
        # 3. move τ against the gap, then re-impose the restrictions on τ
        τ = τ .- step .* gap
        for k in 1:n
            τ[k, k] = 1.0          # domestic trade is costless
        end
        τ = max.(τ, 1.0)           # trade costs are bounded below by one
        # NOTE(review): project routine; presumably enforces τ_ik ≤ τ_ij * τ_jk — confirm
        τ = triangle_inequality(τ)
        dist = mean(abs.(gap[offdiag]))
        dist < tol && break
    end
    return τ, p, excess, dist, iterations
end

# --- Initialize: uniform cost of 3 on foreign trade, none on domestic trade
τ_mat = [i == j ? 1.0 : 3.0 for i in 1:N, j in 1:N];
pj = ones(N, 1);     # initial price guess (N×1)
S = zj' .* Lj'       # 1×N supplies: productivity times country size
countmax = 900;      # maximum number of outer iterations

# --- calibrate; the last country (index N) is the US numeraire
τ_mat, pj, H, dist, count = calibrate_trade_costs(mij_data, τ_mat, pj, S, eta;
                                                  maxiter=countmax);
@info "Trade-cost calibration finished" iterations = count countmax mean_abs_share_gap = dist
# ---------------------------------------------------------


# ---------------------------------------------------------
# Get the implied costs of trade from the matrix of trade

# --- inward resistance: CES aggregate of delivered prices τ_ij * p_j over each row
pij = τ_mat .* pj'
Pj = sum(pij .^ (1 - eta); dims=2) .^ (1 / (1 - eta));

# --- outward resistance, from wages (price × productivity) times country size
wj = pj .* zj;
PIj = (wj .* Lj) .^ (1 / (1 - eta)) ./ pj;

# --- Summary per country, later rendered as the LaTeX table:
# average import trade cost (row mean of τ - 1), average export trade cost
# (column mean of τ - 1), inward and outward multilateral resistance.
# NOTE(review): the row/column means include the zero own-country entry (τ_ii = 1);
# confirm this is the intended definition of the averages.
df_τ = hcat(
    country_names,
    DataFrame(
        avg_import_costs   = vec(mean(τ_mat .- 1.0; dims=2)),
        avg_export_costs   = vec(mean(τ_mat .- 1.0; dims=1)),
        inward_resistance  = vec(Pj),
        outward_resistance = vec(PIj),
    ),
)
# ---------------------------------------------------------


# ---------------------------------------------------------
# Generate LaTeX Table
#
# Layout: one row per country (name followed by four statistics), then a footer with
# the cross-country mean and standard deviation of each of the four columns.

# Columns of `df_τ` shown in the table, in display order.
table_columns = (:avg_import_costs, :avg_export_costs, :inward_resistance, :outward_resistance)

# Format a value with exactly two decimals so every cell of the table is aligned.
fmt2(x) = @sprintf("%.2f", x)

# One LaTeX row: `label & v1 & ... & vk \\` followed by a newline.
latex_row(label, values) =
    string(label, " & ", join((fmt2(v) for v in values), " & "), raw" \\ ", "\n")

table_header = string(
    "\n\n",
    "% Table generated by julia code table9.jl",
    "\n\n",
    "\\begin{tabular}{lcccc} \n",
    "\\toprule \n",
    raw"    & Average Import & Average Export & Inward & Outward \\ ", "\n",
    raw"    & Trade Costs & Trade Costs & Resistance & Resistance \\ ", "\n",
    "\\midrule \n\n",
)

# built with `join` rather than repeated `*=` on a global string
table_body = join(latex_row(row.Fullname, (row[c] for c in table_columns))
                  for row in eachrow(df_τ))

# summary rows use each of the four columns in order (including export costs)
table_footer = string(
    "\n\n",
    "\\midrule \n",
    latex_row("Mean", (mean(df_τ[!, c]) for c in table_columns)),
    latex_row("Standard Deviation", (std(df_τ[!, c]) for c in table_columns)),
    "\\bottomrule \n",
    "\\end{tabular} \n\n",
)

# make sure the output directory exists; the do-block closes the file even on error
mkpath("./output")
open("./output/table_9.tex", "w") do io
    write(io, table_header, table_body, table_footer)
end
# ---------------------------------------------------------


